# Kernel: hgemm_nn_128x128

# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


<CONSTANT_MAPPING>
    // Byte offset 4*128*8*4 = 16384. The prologue stores RZ here with STS.128 and
    // loads it back with LDS to obtain zero-filled registers.
    addr_zero  : 4x<128*8*4>

    // Constant bank 0 entries. gridDimA is multiplied by blkB when the per-thread
    // seed index is built; gridDimB is not referenced in this file.
    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    // Kernel parameters in constant bank 0, starting at 0x140.
    // The [0]/[1] pairs are used as the low/high 32-bit words of 64-bit pointers
    // (see the LEA / LEA.HI.X pairs in the prologue).
    param_Rand[0]   : c[0x0][0x140]
    param_Rand[1]   : c[0x0][0x144]
    param_A[0]      : c[0x0][0x148]
    param_A[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_C[0]      : c[0x0][0x158]
    param_C[1]      : c[0x0][0x15c]
    param_lda       : c[0x0][0x160]
    // param_ldb8 is shifted right by 4 to get ldb in the prologue, and is added
    // unshifted to the B pointer each time a tile of B has been consumed.
    param_ldb8      : c[0x0][0x164]
    param_ldc       : c[0x0][0x168]
    param_m         : c[0x0][0x16c]
    param_n         : c[0x0][0x170]
    param_k         : c[0x0][0x174]
    param_alpha     : c[0x0][0x178]
    param_beta      : c[0x0][0x17c]
    // Bit 0 of param_flags gates the seed load from param_Rand in the prologue.
    param_flags     : c[0x0][0x180]
    // Per-blkZ strides for A, B and C (ldaz/ldbz are multiplied by blkZ below).
    param_ldaz      : c[0x0][0x184]
    param_ldbz      : c[0x0][0x188]
    param_ldcz      : c[0x0][0x18c]
    // param_loops is not referenced in this file.
    param_loops     : c[0x0][0x190]
</CONSTANT_MAPPING>

<REGISTER_MAPPING>

    // Register ranges below deliberately overlap: several name groups share the
    // same physical registers. NOTE(review): this relies on the groups being live
    // in different phases of the kernel -- confirm against the included common file.

    // Prologue / remainder temporaries (share 64-95 with the j0/j1 operand registers).
    64-95   ~ tidAX, tidBX, lda, ldb, ldaz, ldbz, tid1, tid7, tid31, tid128, tbid, txa, xmad_ta, xmad_tb, dimA, flag, k<1-3>, x<1-3>

    // czeroNN aliases registers 0-63, the same registers named cx?y? below.
    // The prologue fills them with zeros via LDS.U.128 from addr_zero.
    0-63    : czero<00-63>

    // Explicit register numbers for the 8x8 block of cx<0-7>y<0-7> values.
    // NOTE(review): the interleaved numbering is presumably chosen to avoid register
    // bank conflicts in the FFMA code of the common file -- confirm there.
     3, 2,11,10,19,18,27,26 : cx<0-7>y0
     7, 6,15,14,23,22,31,30 : cx<0-7>y1
     1, 0, 9, 8,17,16,25,24 : cx<0-7>y2
     5, 4,13,12,21,20,29,28 : cx<0-7>y3
    35,34,43,42,51,50,59,58 : cx<0-7>y4
    39,38,47,46,55,54,63,62 : cx<0-7>y5
    33,32,41,40,49,48,57,56 : cx<0-7>y6
    37,36,45,44,53,52,61,60 : cx<0-7>y7

    // Two sets of A/B operand registers (j0*, j1*); not referenced in this file.
    64-79   : j0Ay<0-7>, j0Bx<0-7>
    80-95   : j1Ay<0-7>, j1Bx<0-7>

    // Staging registers for values loaded from global memory before conversion.
    // loadB0-3 are contiguous so they can be stored with a single STS.128.
    96-105  : loadB<0-3>, loadA<0-5>

    // 64-bit global pointers into A and B as low/high register pairs.
    106-109 : trackA<0-1>, trackB<0-1>

    // State that is used by both the prologue and the main loop.
    110-118 ~ writeAs, writeBs, k, txb, tidAY, tidBY, ta, tb, loop
    119-125 ~ readAs, readBs, tid, blkA, blkB, blkZ, seed
    // 64-bit pointer from which the per-thread seed is loaded.
    126-127 : Rand<0-1>

    // The remaining groups are not referenced in this file.
    // NOTE(review): presumably used by the output stage in the common file -- confirm.
    64-75   ~ ldc, ldcz, ci, xmad_c, clk_shf1, clk_shf2, tid_31, tid_96, tid_128

    64-75   : c<0-7>, d3, d2, d1, d0
    76-85   : C00y<0-1>, C04y<0-1>, C08y<0-1>, C12y<0-1>
    86-107  ~ lfsr<0-2>, ldc1, ldc4, ldc60, writeCs, readCs, cx<00|64>, cy<00|04|08|12>, alpha, beta

    108-125 ~ exp<0-3>, rand<0-3>, lfsr<0-2>_1, lfsr<0-2>_2, flags


</REGISTER_MAPPING>

// Kernel prologue: read thread and block ids, zero the registers 0-63, pick a
// per-thread seed, and derive the global pointers (trackA/trackB), bounds
// predicates (P5/P6) and shared memory addresses (writeAs/writeBs/readAs/readBs).
//
// Control code fields are wait-mask : read-barrier : write-barrier : yield : stall.
// Each S2R below signals its own write barrier (1-4); the first consumer of each
// value carries the matching bit in its wait mask.
--:-:1:-:1      S2R tid,  SR_TID.X;
--:-:2:-:1      S2R blkA, SR_CTAID.Y;
--:-:3:-:1      S2R blkB, SR_CTAID.Z;
--:-:4:-:1      S2R blkZ, SR_CTAID.X;

<SCHEDULE_BLOCK>
--:-:-:-:1      MOV k,   param_k;
--:-:-:-:1      MOV lda, param_lda;
// ldb = param_ldb8 >> 4
--:-:-:-:1      MOV ldb, param_ldb8;
--:-:-:-:1      SHR.U32 ldb, ldb, 4;
--:-:-:-:1      MOV ldaz, param_ldaz;
--:-:-:-:1      MOV ldbz, param_ldbz;
--:-:-:-:1      MOV loop, RZ;

// Write 16 bytes of zero at addr_zero, then load them back into czero00, czero04,
// ... czero60 (16 x 128-bit loads) so that registers 0-63 all start at zero.
--:-:-:-:1      STS.128 [addr_zero], RZ;
<CODE>
        join('', map sprintf("--:-:-:-:1      LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15);
</CODE>

// Grab a seed for this thread
// index = (blkB*gridDimA*256 + blkA*256 + tid) & (2048*32 - 1)
// Rand  = param_Rand + index * 4
// P4 = (param_flags & 1) != 0; the seed is only loaded from [Rand] when P4 is set,
// otherwise the seed register keeps the masked index.
--:-:-:-:1      MOV flag, param_flags;
--:-:-:-:1      LOP.AND.NZ P4, RZ, flag, 0x1;
--:-:-:-:1      MOV dimA, gridDimA;
// wait mask 03: needs tid (barrier 1) and blkA (barrier 2)
03:-:-:-:1      ISCADD tbid, blkA, tid, 8;
// wait mask 04: needs blkB (barrier 3)
04:-:-:-:1      XMAD.U16.U16 dimA, blkB, dimA, RZ;
--:-:-:-:1      ISCADD tbid, dimA, tbid, 8;
--:-:-:-:1      LOP.AND seed, tbid, 1x<2048*32 - 1>;
--:-:-:-:1      LEA      Rand0.CC, seed, param_Rand[0],     0x2;
--:-:-:-:1      LEA.HI.X Rand1,    seed, param_Rand[1], RZ, 0x2;
// NOTE(review): this load sets no write barrier here; confirm that the consumer of
// seed in the common file is ordered correctly with respect to it.
--:-:-:-:1  @P4 LDG.E.CS seed, [Rand];

01:-:-:-:1      LOP.AND tid31,  tid,  31;
--:-:-:-:1      LOP.AND tid128, tid,  128;

// tidAY  = (tid & 1) << 2
--:-:-:-:1      LOP.AND tid1,  tid,  1;
--:-:-:-:1      SHL     tidAY, tid1, 2;

// tidAX = tid >> 1
--:-:-:-:1      SHR.U32 tidAX, tid,   1;

// txa     = blkA*128 + tidAX
// ta      = txa * lda + tidAY + ldaz * blkZ
// trackA  = param_A + 2 * ta
02:-:-:-:1      ISCADD   txa, blkA, tidAX, 7;
--:-:-:-:1      XMAD.LO  ta,  lda,  txa,   tidAY, xmad_ta;
// wait mask 08: needs blkZ (barrier 4)
08:-:-:-:1      XMAD.LO2 ta,  ldaz, blkZ,  ta;
--:-:-:-:1      LEA      trackA0.CC, ta, param_A[0],     0x1;
--:-:-:-:1      LEA.HI.X trackA1,    ta, param_A[1], RZ, 0x1;

// P5 = txa < m
--:-:-:-:1      ISETP.LT.AND P5, PT, txa, param_m, PT;

// tidBX = (tid & 31) << 2
// tidBY = (tid >> 5) & 7
--:-:-:-:1      SHL     tidBX, tid31, 2;
--:-:-:-:1      BFE.U32 tidBY, tid,  0x305; // 3 bits at position 5

// txb     = blkB*128 + tidBX
// tb      = ldb * tidBY + txb + ldbz * blkZ
// trackB  = param_B + 2 * tb
04:-:-:-:1      ISCADD   txb, blkB, tidBX, 7;
--:-:-:-:1      XMAD.LO2 tb,  ldb,  tidBY, txb;
08:-:-:-:1      XMAD.LO2 tb,  ldbz, blkZ,  tb;
--:-:-:-:1      LEA      trackB0.CC, tb, param_B[0],     0x1;
--:-:-:-:1      LEA.HI.X trackB1,    tb, param_B[1], RZ, 0x1;

// P6 = txb < n
--:-:-:-:1      ISETP.LT.AND P6, PT, txb, param_n, PT;

// writeAs = 4 * (128 * tidAY + tidAX) + 4 * (128*8*2)
--:-:-:-:1      ISCADD  writeAs, tidAY, tidAX, 7;
--:-:-:-:1      ISCADD  writeAs, writeAs, 4x<128*8*2>, 2;


// writeBs = 4 * (128 * tidBY + tidBX) + 4 * (128*8*3)
--:-:-:-:1      ISCADD  writeBs, tidBY, tidBX, 7;
--:-:-:-:1      ISCADD  writeBs, writeBs, 4x<128*8*3>, 2;

// readAs  = (((tid & 0x70) >> 3) | (tid & 1)) << 4
--:-:-:-:1      LOP.AND readAs, tid,    0x70;
--:-:-:-:1      SHR.U32 readAs, readAs, 3;
--:-:-:-:1      LOP.OR  readAs, readAs, tid1;
--:-:-:-:1      SHL     readAs, readAs, 4;

// readBs = (((tid128 >> 4) | ((tid >> 1) & 7)) << 4) + 4096
--:-:-:-:1      BFE.U32 tid7,   tid,    0x301; // 3 bits at position 1
--:-:-:-:1      SHR.U32 readBs, tid128, 4;
--:-:-:-:1      LOP.OR  readBs, readBs, tid7;
--:-:-:-:1      ISCADD  readBs, readBs, 4x<128*8>, 4;
</SCHEDULE_BLOCK>

// REMAINDER: bounds-checked global loads of one tile of A and B into the loadA*/loadB*
// staging registers. Execution falls into this label from the prologue, and the loop
// code branches back to it when P1 is set (see the j7c63 entry of the insert table).
REMAINDER:

<CODE>
    # $vec is a package variable set outside this file; it selects the vectorised
    # (64-bit load) or the scalar (per-element 16-bit load) code path.
    our $vec;
    return $vec ? q{
// Vector path.
// P3 = tidBY < k && P6 : the B load is in bounds
// P2 = tidAY < k && P5 : the A loads are in bounds
--:-:-:-:2      ISETP.LT.AND P3, PT, tidBY, k, P6;
--:-:-:Y:b      ISETP.LT.AND P2, PT, tidAY, k, P5;

// Each 64-bit load fetches four fp16 values. The B load signals barrier 4 and the
// two A loads signal barrier 2.
--:-:4:-:2  @P3 LDG.E.CI.64 loadB0, [trackB];
--:-:2:-:1  @P2 LDG.E.CI.64 loadA0, [trackA + 2x<0>];
--:-:2:-:1  @P2 LDG.E.CI.64 loadA4, [trackA + 2x<8>];

// P4 = true. P4 is inverted at slot j0c1 of every loop iteration in the vector path.
--:-:-:-:0      PSETP.AND.AND P4, PT, PT, PT, PT;

// Out-of-bounds threads read zeros from shared memory at addr_zero instead
// (barrier 5 for B, barrier 6 for A).
--:-:5:-:1 @!P3 LDS.U.64 loadB0, [addr_zero];
--:-:6:-:1 @!P2 LDS.U.64 loadA0, [addr_zero];
--:-:6:-:1 @!P2 LDS.U.64 loadA4, [addr_zero];
    } : q{

<SCHEDULE_BLOCK>
// Scalar path: four individually predicated 16-bit loads for B and for A.
// x1..x3 are the indices of the three elements that follow txb.
--:-:-:-:1      IADD x1, txb, 1;
--:-:-:-:1      IADD x2, txb, 2;
--:-:-:-:1      IADD x3, txb, 3;

// P0 = tidBY < k && txb < n  (P6 is txb < n)
// P1..P3 = P0 && x1..x3 < n
--:-:-:-:1      ISETP.LT.AND P0, PT, tidBY, k, P6;
--:-:-:-:1      ISETP.LT.AND P1, PT, x1, param_n, P0;
--:-:-:-:1      ISETP.LT.AND P2, PT, x2, param_n, P0;
--:-:-:-:1      ISETP.LT.AND P3, PT, x3, param_n, P0;

--:-:4:-:1  @P0 LDG.E.CI.S16 loadB0, [trackB + 2x<00 + 0>];
--:-:4:-:1  @P1 LDG.E.CI.S16 loadB1, [trackB + 2x<00 + 1>];
--:-:4:-:1  @P2 LDG.E.CI.S16 loadB2, [trackB + 2x<00 + 2>];
--:-:4:-:1  @P3 LDG.E.CI.S16 loadB3, [trackB + 2x<00 + 3>];

// Elements that are out of bounds are set to zero.
--:-:-:-:1 @!P0 MOV loadB0, RZ;
--:-:-:-:1 @!P1 MOV loadB1, RZ;
--:-:-:-:1 @!P2 MOV loadB2, RZ;
--:-:-:-:1 @!P3 MOV loadB3, RZ;


// k1..k3 are the indices of the three elements that follow tidAY.
--:-:-:-:1      IADD k1, tidAY, 1;
--:-:-:-:1      IADD k2, tidAY, 2;
--:-:-:-:1      IADD k3, tidAY, 3;

// P0..P3 = (tidAY + 0..3) < k && txa < m  (P5 is txa < m)
--:-:-:-:1      ISETP.LT.AND P0, PT, tidAY, k, P5;
--:-:-:-:1      ISETP.LT.AND P1, PT, k1, k, P5;
--:-:-:-:1      ISETP.LT.AND P2, PT, k2, k, P5;
--:-:-:-:1      ISETP.LT.AND P3, PT, k3, k, P5;

--:-:2:-:1  @P0 LDG.E.CI.S16 loadA0, [trackA + 2x<0>];
--:-:2:-:1  @P1 LDG.E.CI.S16 loadA1, [trackA + 2x<1>];
--:-:2:-:1  @P2 LDG.E.CI.S16 loadA2, [trackA + 2x<2>];
--:-:2:-:1  @P3 LDG.E.CI.S16 loadA3, [trackA + 2x<3>];

--:-:-:-:1 @!P0 MOV loadA0, RZ;
--:-:-:-:1 @!P1 MOV loadA1, RZ;
--:-:-:-:1 @!P2 MOV loadA2, RZ;
--:-:-:-:1 @!P3 MOV loadA3, RZ;
</SCHEDULE_BLOCK>
    };
</CODE>

<CODE>
    # Convert the fp16 values staged in loadA*/loadB* to fp32, store them to shared
    # memory at writeAs/writeBs and advance the global pointers. This code is
    # hand-scheduled: the stall counts and barrier numbers depend on the exact order.
    our $vec;
    return $vec ? q{
// bDoRemainder = k & 7 && k > 8
// P1 is built in two steps: (k & 7) != 0 here, ANDed with k > 8 further down.
--:-:-:-:0      LOP.AND.NZ P1, RZ, k, 7;

// Wait mask 18 = barriers 4 and 5, the LDG and LDS that can have written loadB0/1.
// The conversions run from the highest half down so that no source register is
// overwritten before it has been read.
18:-:-:-:4      F2F.F32.F16 loadB3, loadB1.H1;
// trackB += param_ldb8 (64-bit add, the carry is consumed by IADD.X below)
--:-:-:-:0      IADD   trackB0.CC, trackB0, param_ldb8;
--:-:-:-:4      F2F.F32.F16 loadB2, loadB1.H0;
--:-:-:-:4      F2F.F32.F16 loadB1, loadB0.H1;
--:-:4:-:2      F2F.F32.F16 loadB0, loadB0.H0;

--:-:-:-:0      IADD.X trackB1, trackB1, RZ;

// Wait for the last B conversion (barrier 4), then store all four floats at once.
08:-:-:-:1      STS.128 [writeBs], loadB0;

// Wait mask 22 = barriers 2 and 6, the LDG and LDS that can have written loadA0/1.
22:-:-:-:4      F2F.F32.F16 loadA3, loadA1.H1;
// trackA += 16 fp16 elements
--:-:-:-:0      IADD   trackA0.CC, trackA0, 2x<16>;
--:-:2:-:4      F2F.F32.F16 loadA2, loadA1.H0;
--:-:-:-:4      F2F.F32.F16 loadA1, loadA0.H1;
--:-:-:-:0      ISETP.GT.AND P1, PT, k, 8, P1;
--:-:3:-:1      F2F.F32.F16 loadA0, loadA0.H0;

--:-:-:-:0      IADD.X trackA1, trackA1, RZ;

// The four A values are stored 128 floats apart. Barrier 2 covers loadA3/loadA2
// and barrier 3 covers loadA1/loadA0.
02:-:-:-:1      STS [writeAs + 4x<3*128>], loadA3;
--:-:-:-:1      STS [writeAs + 4x<2*128>], loadA2;
04:-:-:-:1      STS [writeAs + 4x<1*128>], loadA1;
--:-:-:-:1      STS [writeAs + 4x<0*128>], loadA0;
    } : q{
// Scalar path: P1 = k > 8
--:-:-:-:0      ISETP.GT.AND P1, PT, k, 8, PT;

// Wait for the B loads (barrier 4), then convert each element in place.
08:-:-:-:4      F2F.F32.F16 loadB0, loadB0;
// trackB += param_ldb8 (64-bit add, the carry is consumed by IADD.X below)
--:-:-:-:0      IADD   trackB0.CC, trackB0, param_ldb8;
--:-:-:-:4      F2F.F32.F16 loadB1, loadB1;
--:-:-:-:4      F2F.F32.F16 loadB2, loadB2;
--:-:4:-:2      F2F.F32.F16 loadB3, loadB3;

--:-:-:-:0      IADD.X trackB1, trackB1, RZ;

08:-:-:-:1      STS.128 [writeBs], loadB0;

// Wait for the A loads (barrier 2), then convert each element in place.
02:-:-:-:4      F2F.F32.F16 loadA0, loadA0;
// trackA += 8 fp16 elements
--:-:-:-:0      IADD   trackA0.CC, trackA0, 2x<8>;
--:-:2:-:4      F2F.F32.F16 loadA1, loadA1;
--:-:-:-:4      F2F.F32.F16 loadA2, loadA2;
--:-:3:-:1      F2F.F32.F16 loadA3, loadA3;

--:-:-:-:0      IADD.X trackA1, trackA1, RZ;

// Barrier 2 covers loadA0/loadA1 and barrier 3 covers loadA2/loadA3.
02:-:-:-:1      STS [writeAs + 4x<0*128>], loadA0;
--:-:-:-:1      STS [writeAs + 4x<1*128>], loadA1;
04:-:-:-:1      STS [writeAs + 4x<2*128>], loadA2;
--:-:-:-:1      STS [writeAs + 4x<3*128>], loadA3;
    };
</CODE>

// Flip all four shared memory addresses by 4*(128*8*2) bytes around the barrier.
// The read addresses move to the region that was just written, and the write
// addresses move to the other region.
--:-:-:-:1      LOP.XOR readAs, readAs, 4x<128*8*2>;
--:-:-:-:0      LOP.XOR readBs, readBs, 4x<128*8*2>;
// Wait mask 01 clears barrier 1 before the block-wide barrier is entered.
01:-:-:-:5      BAR.SYNC 0;
--:-:-:-:1      LOP.XOR writeAs, writeAs, 4x<128*8*2>;
--:-:-:-:0      LOP.XOR writeBs, writeBs, 4x<128*8*2>;



<CODE>
    # Defines the package variables @top and %insert and emits no text itself.
    # NOTE(review): both are presumably consumed by the included
    # hgemm_common_128x128.sass, which is expected to splice each %insert value into
    # the main loop at the slot named by its key and to emit @top first -- confirm
    # there. Keys have the form jNcM; the meaning of N and M is defined by that file.
    #
    # Predicates written by these snippets:
    #   P3 = k >= k_end && P6   guards the B loads
    #   P0 = k >= k_end         guards the stores, pointer updates and the loop branch
    #   P2                      guards the A loads (defined differently per path)
    #   P4                      vector path only, inverted once per iteration
    # P1, P5 and P6 are only read here; they are set earlier in this file.
    our $vec;
    # Threshold on k below which no further tile is loaded: 16 (vector), 24 (scalar).
    my $k_end = $vec ? 16 : 24;
    our @top = ("--:-:-:-:1      ISETP.GE.AND P3, PT, k, $k_end, P6;\n");
    our %insert =
    (
        ($vec ? 
            (
        # Vector path. P4 is inverted every iteration and P2 = P0 && P4 && P5, so the
        # two 64-bit A loads are issued only on iterations where P4 is set.
        j0c1  => "--:-:-:-:1      PSETP.AND.AND P4, PT, !P4, PT, PT;\n",
        j0c3  => "--:-:-:-:1      ISETP.GE.AND  P0, PT, k, $k_end, PT;\n",
        j0c15 => "--:-:-:-:1      PSETP.AND.AND P2, PT, P0, P4, P5;\n",

        # Global loads: B signals barrier 2, loadA0 barrier 5, loadA4 barrier 6.
        j0c10 => "--:-:2:-:1  \@P3 LDG.E.CI.64 loadB0, [trackB];\n",
        
        j0c28 => "--:-:5:-:1  \@P2 LDG.E.CI.64 loadA0, [trackA + 2x<0>];\n",
        j0c30 => "20:4:6:-:1  \@P2 LDG.E.CI.64 loadA4, [trackA + 2x<8>];\n",

        # When P4 is clear, convert the halves held in loadA4/loadA5.
        j4c5  => "--:-:-:-:1 \@!P4 F2F.F32.F16 loadA3, loadA5.H1;\n",
        j4c9  => "--:-:-:-:1 \@!P4 F2F.F32.F16 loadA2, loadA5.H0;\n",
        j4c13 => "--:-:-:-:1 \@!P4 F2F.F32.F16 loadA1, loadA4.H1;\n",
        j4c17 => "--:-:-:-:1 \@!P4 F2F.F32.F16 loadA0, loadA4.H0;\n",

        # Convert the four B halves in loadB0/loadB1, then store them with one STS.128.
        j5c5  => "02:-:-:-:1  \@P0 F2F.F32.F16 loadB3, loadB1.H1;\n",
        j5c9  => "--:-:-:-:1  \@P0 F2F.F32.F16 loadB2, loadB1.H0;\n",
        j5c13 => "--:-:-:-:1  \@P0 F2F.F32.F16 loadB1, loadB0.H1;\n",
        j5c17 => "--:-:2:-:1  \@P0 F2F.F32.F16 loadB0, loadB0.H0;\n",

        j5c35 => "02:-:-:-:1  \@P0 STS.128 [writeBs], loadB0;\n",

        # When P4 is set, convert the freshly loaded halves in loadA0/loadA1. Each
        # conversion signals its own barrier (2-5), which the matching store waits on.
        j6c5  => "10:-:2:-:1  \@P4 F2F.F32.F16 loadA3, loadA1.H1;\n",
        j6c9  => "--:-:3:-:1  \@P4 F2F.F32.F16 loadA2, loadA1.H0;\n",
        j6c13 => "--:-:4:-:1  \@P4 F2F.F32.F16 loadA1, loadA0.H1;\n",
        j6c17 => "--:-:5:-:1  \@P4 F2F.F32.F16 loadA0, loadA0.H0;\n",

        j6c29 => "02:-:-:-:1  \@P0 STS [writeAs + 4x<3*128>], loadA3;\n",
        j6c31 => "04:-:-:-:1  \@P0 STS [writeAs + 4x<2*128>], loadA2;\n",
        j6c33 => "08:-:-:-:1  \@P0 STS [writeAs + 4x<1*128>], loadA1;\n",
        j6c35 => "10:-:-:-:1  \@P0 STS [writeAs + 4x<0*128>], loadA0;\n",

        # trackA += 16 fp16 elements on P4 iterations. Wait mask 08 is the read
        # barrier (4) set by the loadA4 load, which still needs the old trackA.
        j6c11 => "08:-:-:-:1  \@P4 IADD   trackA0.CC, trackA0, 2x<16>;\n",
        j6c54 => "--:-:-:-:1  \@P4 IADD.X trackA1,    trackA1, RZ;\n",
            ) :
            (
        # Scalar path. P2 = k >= k_end && P5 guards the A loads and conversions.
        j0c1  => "--:-:-:-:1      ISETP.GE.AND P2, PT, k, $k_end, P5;\n",
        j0c3  => "--:-:-:-:1      ISETP.GE.AND P0, PT, k, $k_end, PT;\n",

        # Four 16-bit B loads (barrier 2) and four 16-bit A loads (barrier 6).
        j0c10 => "--:-:2:-:1  \@P3 LDG.E.CI.S16 loadB0, [trackB + 2x<0>];\n",
        j0c12 => "--:-:2:-:1  \@P3 LDG.E.CI.S16 loadB1, [trackB + 2x<1>];\n",
        j0c14 => "--:-:2:-:1  \@P3 LDG.E.CI.S16 loadB2, [trackB + 2x<2>];\n",
        j0c16 => "--:-:2:-:1  \@P3 LDG.E.CI.S16 loadB3, [trackB + 2x<3>];\n",

        j0c29 => "--:-:6:-:1  \@P2 LDG.E.CI.S16 loadA0, [trackA + 2x<0>];\n",
        j0c31 => "--:-:6:-:1  \@P2 LDG.E.CI.S16 loadA1, [trackA + 2x<1>];\n",
        j0c33 => "--:-:6:-:1  \@P2 LDG.E.CI.S16 loadA2, [trackA + 2x<2>];\n",
        j0c35 => "--:-:6:-:1  \@P2 LDG.E.CI.S16 loadA3, [trackA + 2x<3>];\n",

        # In-place fp16 to fp32 conversion of B, then a single 128-bit store.
        j5c8  => "02:-:-:-:1  \@P3 F2F.F32.F16 loadB0, loadB0;\n",
        j5c12 => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB1, loadB1;\n",
        j5c16 => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB2, loadB2;\n",
        j5c20 => "--:-:2:-:1  \@P3 F2F.F32.F16 loadB3, loadB3;\n",

        j5c39 => "02:-:-:-:1  \@P0 STS.128 [writeBs], loadB0;\n",

        # In-place conversion of A. Each conversion signals its own barrier (2-5),
        # which the matching store waits on.
        j6c5  => "20:-:2:-:1  \@P2 F2F.F32.F16 loadA0, loadA0;\n",
        j6c9  => "--:-:3:-:1  \@P2 F2F.F32.F16 loadA1, loadA1;\n",
        j6c13 => "--:-:4:-:1  \@P2 F2F.F32.F16 loadA2, loadA2;\n",
        j6c17 => "--:-:5:-:1  \@P2 F2F.F32.F16 loadA3, loadA3;\n",

        j6c29 => "02:-:-:-:1  \@P0 STS [writeAs + 4x<0*128>], loadA0;\n",
        j6c31 => "04:-:-:-:1  \@P0 STS [writeAs + 4x<1*128>], loadA1;\n",
        j6c33 => "08:-:-:-:1  \@P0 STS [writeAs + 4x<2*128>], loadA2;\n",
        j6c35 => "10:-:-:-:1  \@P0 STS [writeAs + 4x<3*128>], loadA3;\n",

        # trackA += 8 fp16 elements.
        j6c46 => "--:-:-:-:1  \@P2 IADD   trackA0.CC, trackA0, 2x<8>;\n",
        j6c54 => "--:-:-:-:1  \@P2 IADD.X trackA1,    trackA1, RZ;\n",
            )
        ),

        # Entries shared by both paths.
        # trackB += param_ldb8 (64-bit add with carry).
        j5c46 => "--:-:-:-:1  \@P0 IADD   trackB0.CC, trackB0, param_ldb8;\n",
        j5c54 => "--:-:-:-:1  \@P0 IADD.X trackB1,    trackB1, RZ;\n",

        # k -= 8 unconditionally; when P0 is set, synchronise the block and flip the
        # shared memory read and write addresses by 4*(128*8*2) bytes.
        j6c63 => "--:-:-:-:0      IADD32I k, k, -8;\n" .
                 "--:-:-:-:5  \@P0 BAR.SYNC 0;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR readAs, readAs, 4x<128*8*2>;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR readBs, readBs, 4x<128*8*2>;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR writeAs, writeAs, 4x<128*8*2>;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR writeBs, writeBs, 4x<128*8*2>;\n",

        # Branch back to LOOP while P0 is set; otherwise branch to REMAINDER when P1 is
        # set. The LOOP label is not defined in this file.
        j7c63 => "--:-:-:Y:5  \@P0 BRA.U LOOP;\n" .
                 "--:-:-:Y:5  \@P1 BRA.U REMAINDER;\n",
    );
    return;
</CODE>

<INCLUDE file="nervanagpu/kernels/sass/hgemm_common_128x128.sass"/>
